package cn.zifangsky.hashtable.lru;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiConsumer;

/**
 * A cache that evicts entries using the "Least Recently Used" (LRU) policy.
 *
 * <p>Entries are kept in a doubly linked list ordered from least recently used (head) to most
 * recently used (tail), plus a {@link HashMap} from key to list node for constant-time lookup.
 * {@link #put} appends at the tail, and {@link #get} moves the accessed entry to the tail. When
 * the number of entries exceeds the capacity, expired entries are purged first. Then entries are
 * evicted from the head until the cache fits again.
 *
 * <p>Entries may carry an expiration time. Expired entries are removed in three ways:
 * <ul>
 *   <li>lazily by {@link #get};</li>
 *   <li>when the cache overflows;</li>
 *   <li>optionally by a background task that runs every 2 seconds.</li>
 * </ul>
 * Call {@link #close()} to stop that background task once the cache is no longer needed.
 *
 * <p>Thread-safety: all public operations are guarded by a per-instance {@link ReentrantLock}.
 *
 * @author zifangsky
 * @date 2020/6/25
 * @since 1.0.0
 */
public class LRUCache<K, V> implements AutoCloseable {
    /**
     * Default cache capacity.
     */
    public static final int DEFAULT_INITIAL_CAPACITY = 16;

    /**
     * A single cache entry, which is also a node of the doubly linked list.
     */
    public static class Node<K, V> {
        /**
         * Entry key.
         */
        final K key;
        /**
         * Entry value.
         */
        V value;

        /**
         * Expiration timestamp in epoch milliseconds. A value {@code <= 0} means "never expires".
         */
        long expired;

        /**
         * Previous node, towards the least recently used end of the list.
         */
        Node<K, V> previous;

        /**
         * Next node, towards the most recently used end of the list.
         */
        Node<K, V> next;

        public Node(K key, V value) {
            this(key, value, -1);
        }

        public Node(K key, V value, long expired) {
            this(key, value, expired, null, null);
        }

        public Node(K key, V value, long expired, Node<K, V> previous, Node<K,V> next) {
            this.key = key;
            this.value = value;
            this.expired = expired;
            this.previous = previous;
            this.next = next;
        }

        /**
         * Returns whether this entry has expired at the given time.
         *
         * @param now current time in epoch milliseconds
         * @return {@code true} if an expiration time is set and it is not after {@code now}
         */
        boolean isExpired(long now) {
            return this.expired > 0 && this.expired <= now;
        }

        @Override
        public final String toString() {
            return key + "=" + value;
        }

        /**
         * Two nodes are equal when their keys and values are equal (null-safe).
         * Links and expiration time are ignored.
         */
        @Override
        public final boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Node<?, ?> node = (Node<?, ?>) o;
            return Objects.equals(key, node.key) && Objects.equals(value, node.value);
        }

        /**
         * XOR of the hash codes of the key and the value, consistent with {@link #equals(Object)}.
         */
        @Override
        public final int hashCode() {
            return Objects.hashCode(key) ^ Objects.hashCode(value);
        }

        /**
         * Null-safe hash code; kept for backward compatibility.
         */
        public int hashCode(Object o) {
            return Objects.hashCode(o);
        }

        /**
         * Null-safe equality; kept for backward compatibility.
         */
        public boolean equals(Object a, Object b) {
            return Objects.equals(a, b);
        }
    }

    /**
     * A minimal doubly linked list of {@link Node}s that backs the cache. It is not thread-safe;
     * the enclosing cache serializes access to it.
     */
    public static class LinkedNodeList<K, V> {
        /**
         * First node (least recently used), or {@code null} when the list is empty.
         */
        Node<K, V> head;
        /**
         * Last node (most recently used), or {@code null} when the list is empty.
         */
        Node<K, V> tail;

        /**
         * Number of nodes in the list.
         */
        int size;

        public LinkedNodeList() {
            this.head = null;
            this.tail = null;
            this.size = 0;
        }

        /**
         * Appends a node at the tail. A {@code null} node is ignored.
         *
         * @param newNode the node to append
         */
        protected void add(Node<K, V> newNode){
            if(newNode == null){
                return;
            }

            if(this.tail == null){
                this.tail = newNode;
                this.head = this.tail;
            }else{
                this.tail.next = newNode;
                newNode.previous = this.tail;
                this.tail = newNode;
            }

            this.size++;
        }

        /**
         * Unlinks and returns the head node.
         *
         * @return the removed node, or {@code null} if the list is empty
         */
        protected Node<K, V> removeHead(){
            if(this.head == null){
                return null;
            }

            Node<K, V> result = this.head;
            if(this.head == this.tail){
                this.head = null;
                this.tail = null;
            }else{
                this.head = result.next;
                this.head.previous = null;
                // Detach the removed node so it holds no stale reference into the list.
                result.next = null;
            }

            this.size--;
            return result;
        }

        /**
         * Unlinks and returns the tail node.
         *
         * @return the removed node, or {@code null} if the list is empty
         */
        protected Node<K, V> removeTail(){
            if(this.tail == null){
                return null;
            }

            Node<K, V> result = this.tail;
            if(this.head == this.tail){
                this.head = null;
                this.tail = null;
            }else{
                this.tail = result.previous;
                this.tail.next = null;
                // Detach the removed node so it holds no stale reference into the list.
                result.previous = null;
            }

            this.size--;
            return result;
        }

        /**
         * Removes the first node whose key equals {@code key}. This is a linear scan.
         * A {@code null} key is ignored.
         */
        protected void remove(K key){
            if(key == null){
                return;
            }

            Node<K, V> temp = this.head;
            while (temp != null){
                if(key.equals(temp.key)){
                    this.remove(temp);
                    break;
                }
                temp = temp.next;
            }
        }

        /**
         * Unlinks the given node, which is assumed to belong to this list, and clears the node's
         * links. A {@code null} node is ignored.
         */
        protected void remove(Node<K, V> node){
            if(node == null){
                return;
            }

            if(node == this.head){
                this.removeHead();
            }else if(node == this.tail){
                this.removeTail();
            }else{
                if(node.previous != null) {
                    node.previous.next = node.next;
                }
                if(node.next != null) {
                    node.next.previous = node.previous;
                }
                node.previous = null;
                node.next = null;
                this.size--;
            }
        }

        /**
         * Removes every node that has expired as of now.
         *
         * @return keys of the removed nodes, in list order
         */
        protected LinkedList<K> removeExpiredNodes(){
            LinkedList<K> expiredKeyList = new LinkedList<>();
            long now = System.currentTimeMillis();

            Node<K, V> current = this.head;
            while (current != null){
                // Capture the successor first: remove() clears the removed node's links.
                Node<K, V> successor = current.next;
                if(current.isExpired(now)) {
                    expiredKeyList.add(current.key);
                    this.remove(current);
                }
                current = successor;
            }

            return expiredKeyList;
        }

        /**
         * Moves a node that already belongs to this list to the tail (most recently used).
         * A {@code null} node, or the current tail, is left in place.
         */
        protected void moveToTail(Node<K, V> node){
            if(node == null || this.tail == node){
                return;
            }

            if(this.head == node){
                this.head = node.next;
                this.head.previous = null;
            }else{
                node.previous.next = node.next;
                node.next.previous = node.previous;
            }

            this.tail.next = node;
            node.previous = this.tail;
            node.next = null;
            this.tail = node;
        }

    }


    /* ---------------- Fields -------------- */
    /**
     * Guards {@link #nodeList} and {@link #keyNodeMap}. It is per instance, so independent caches
     * never contend with each other.
     */
    private final ReentrantLock lock = new ReentrantLock();

    /**
     * Background task that purges expired entries, or {@code null} if auto-checking is disabled.
     */
    private final ScheduledExecutorService checkScheduledExecutor;

    /**
     * Maximum number of entries retained after a {@link #put}.
     */
    private final int capacity;

    /**
     * Entries ordered from least recently used (head) to most recently used (tail).
     */
    private final LinkedNodeList<K, V> nodeList;

    /**
     * Key to list node index, for constant-time lookup.
     */
    private final Map<K, Node<K, V>> keyNodeMap;


    /**
     * Creates a cache with {@link #DEFAULT_INITIAL_CAPACITY} and background expiry checking.
     */
    public LRUCache() {
        this(DEFAULT_INITIAL_CAPACITY, true);
    }

    /**
     * Creates a cache with background expiry checking.
     *
     * @param capacity maximum number of entries
     */
    public LRUCache(int capacity) {
        this(capacity, true);
    }

    /**
     * Creates a cache.
     *
     * @param capacity maximum number of entries; must not be negative
     * @param autoCheckExpiredKey whether to start a background task that purges expired entries
     *                            every 2 seconds; stop it with {@link #close()}
     * @throws IllegalArgumentException if {@code capacity} is negative
     */
    public LRUCache(int capacity, boolean autoCheckExpiredKey) {
        if (capacity < 0) {
            throw new IllegalArgumentException("Illegal capacity: " + capacity);
        }
        this.capacity = capacity;
        this.nodeList = new LinkedNodeList<>();
        this.keyNodeMap = new HashMap<>(capacity);
        this.checkScheduledExecutor = autoCheckExpiredKey ? this.newCheckScheduledExecutor() : null;
    }

    /**
     * Returns the number of cached entries. Expired entries that have not been purged yet are
     * still counted.
     */
    public int size(){
        lock.lock();
        try {
            return this.nodeList.size;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Looks up a value and marks the entry as most recently used. An expired entry is removed and
     * treated as absent.
     *
     * @param key the key; {@code null} yields {@code null}
     * @return the cached value, or {@code null} if absent or expired
     */
    public V get(K key){
        if(key == null){
            return null;
        }

        lock.lock();
        try {
            // The lookup must happen under the lock: otherwise another thread could unlink the
            // node before we move it, corrupting the list.
            Node<K, V> data = this.keyNodeMap.get(key);
            if(data == null){
                return null;
            }

            if(data.isExpired(System.currentTimeMillis())){
                this.remove(data);
                return null;
            }

            this.nodeList.moveToTail(data);
            return data.value;
        }finally {
            lock.unlock();
        }
    }

    /**
     * Stores an entry that never expires.
     *
     * @param key the key; must not be {@code null}
     * @param value the value
     */
    public void put(K key, V value){
        this.put(key, value, -1);
    }

    /**
     * Stores an entry that expires after the given number of seconds.
     *
     * @param key the key; must not be {@code null}
     * @param value the value
     * @param seconds seconds until expiry; {@code <= 0} means never expire
     */
    public void put(K key, V value, int seconds){
        this.put(key, value, seconds, TimeUnit.SECONDS);
    }

    /**
     * Stores an entry as most recently used, replacing any existing entry for the key. If the
     * cache then exceeds its capacity, expired entries are purged first. After that, least
     * recently used entries are evicted until the cache fits.
     *
     * @param key the key; must not be {@code null}
     * @param value the value
     * @param time time until expiry; {@code <= 0} means never expire
     * @param unit unit of {@code time}; must not be {@code null}
     * @throws IllegalArgumentException if {@code key} or {@code unit} is {@code null}
     */
    public void put(K key, V value, int time, TimeUnit unit){
        if(key == null || unit == null){
            throw new IllegalArgumentException("参数存在异常！");
        }

        Node<K, V> newNode;
        if(time > 0){
            long expired = System.currentTimeMillis() + unit.toMillis(time);
            newNode = new Node<>(key, value, expired);
        }else{
            newNode = new Node<>(key, value);
        }

        lock.lock();
        try {
            // Replace an existing entry. Unlink its node directly rather than scanning by key.
            Node<K, V> savedNode = this.keyNodeMap.get(key);
            if(savedNode != null){
                this.remove(savedNode);
            }

            this.nodeList.add(newNode);
            this.keyNodeMap.put(key, newNode);

            if(this.nodeList.size > capacity){
                // Prefer dropping dead entries over evicting live ones.
                this.removeExpiredData();

                while (this.nodeList.size > capacity){
                    Node<K, V> eldest = this.nodeList.removeHead();
                    if(eldest == null){
                        // Defensive: never spin if the size counter and the list disagree.
                        break;
                    }
                    this.keyNodeMap.remove(eldest.key);
                }
            }
        }finally {
            lock.unlock();
        }
    }

    /**
     * Removes the entry for the given key, if present.
     *
     * @param key the key; must not be {@code null}
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void remove(K key){
        if(key == null){
            throw new IllegalArgumentException("参数存在异常！");
        }

        lock.lock();
        try {
            Node<K, V> node = this.keyNodeMap.get(key);
            if(node != null){
                this.remove(node);
            }
        }finally {
            lock.unlock();
        }
    }

    /**
     * Unlinks a node from both the list and the index. The caller must hold {@link #lock}.
     *
     * @param node the node to remove
     * @throws IllegalArgumentException if {@code node} is {@code null}
     */
    private void remove(Node<K, V> node){
        if(node == null){
            throw new IllegalArgumentException("参数存在异常！");
        }

        this.nodeList.remove(node);
        this.keyNodeMap.remove(node.key);
    }

    /**
     * Performs {@code action} for each non-expired entry, from least to most recently used.
     * Iteration does not change recency. The entries are snapshotted under the lock and
     * {@code action} runs outside it, so the action may safely call back into this cache.
     *
     * @param action the action; must not be {@code null}
     * @throws NullPointerException if {@code action} is {@code null}
     */
    public void forEach(BiConsumer<? super K, ? super V> action){
        Objects.requireNonNull(action, "action");

        List<Node<K, V>> snapshot = new ArrayList<>();
        lock.lock();
        try {
            long now = System.currentTimeMillis();
            for (Node<K, V> node = this.nodeList.head; node != null; node = node.next) {
                if (!node.isExpired(now)) {
                    snapshot.add(node);
                }
            }
        } finally {
            lock.unlock();
        }

        // Keys and values are never reassigned after a node is published, so reading them
        // after unlocking is safe.
        for (Node<K, V> node : snapshot) {
            action.accept(node.key, node.value);
        }
    }

    /**
     * Removes all expired entries from both the list and the index.
     */
    private void removeExpiredData(){
        lock.lock();
        try {
            for (K expiredKey : this.nodeList.removeExpiredNodes()) {
                this.keyNodeMap.remove(expiredKey);
            }
        }finally {
            lock.unlock();
        }
    }

    /**
     * Creates a single-threaded scheduler that purges expired entries every 2 seconds, starting
     * 2 seconds from now.
     */
    private ScheduledExecutorService newCheckScheduledExecutor(){
        ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1, r -> {
            Thread thread = new Thread(r);
            // Daemon, so a cache that is never closed does not keep the JVM alive.
            thread.setDaemon(true);
            thread.setName("check-thread-lru");

            return thread;
        });

        executor.scheduleWithFixedDelay(this::removeExpiredData, 2, 2, TimeUnit.SECONDS);
        return executor;
    }

    /**
     * Stops the background expiry task, if one was started. This releases its thread and its
     * reference to this cache. The cache itself remains usable; expired entries are then only
     * removed lazily. Calling this more than once has no further effect.
     */
    @Override
    public void close() {
        if (this.checkScheduledExecutor != null) {
            this.checkScheduledExecutor.shutdown();
        }
    }

}